import yaml
import cv2
import torch
import os
import fire


def pth_write(cfg_file):
    """Embed metadata from a training config into the final checkpoint.

    Reads the YAML file *cfg_file* and loads ``<output_folder>/model_final.pth``
    onto the CPU. It replaces the checkpoint's ``'info'`` entry with a copy of
    the config's ``info`` mapping, plus the paths of ``config.yaml`` and
    ``labels.json`` in the same output folder. The checkpoint is then saved
    back to the same path.

    Args:
        cfg_file: Path to a UTF-8 YAML config containing an ``output_folder``
            key and an optional ``info`` mapping. Each ``info`` entry is
            printed as it is copied.

    Raises:
        KeyError: if the config has no ``output_folder`` key.
        FileNotFoundError: if the config file or the checkpoint is missing.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open(cfg_file, encoding='utf-8') as fh:
        data_dicts = yaml.full_load(fh)

    output_folder = data_dicts['output_folder']
    pth_file = os.path.join(output_folder, 'model_final.pth')
    net = torch.load(pth_file, map_location=torch.device('cpu'))

    # An absent or empty ``info:`` section (YAML null) means "no extra keys"
    # rather than an AttributeError on ``None.items()``.
    info = {}
    for k, v in (data_dicts.get('info') or {}).items():
        print(k, v)
        info[k] = v
    # These two entries take precedence over same-named keys from the config.
    info['config_yaml'] = os.path.join(output_folder, 'config.yaml')
    info['label_map'] = os.path.join(output_folder, 'labels.json')
    net['info'] = info

    # Save to a temporary file first, then atomically swap it in. A failed or
    # interrupted save therefore cannot corrupt the existing checkpoint.
    tmp_file = pth_file + '.tmp'
    try:
        torch.save(net, tmp_file)
        os.replace(tmp_file, pth_file)
    except BaseException:
        # Don't leave a half-written temp file behind; re-raise the error.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise


def main(cfg_file=r'./config/train_config.yaml'):
    """Stamp metadata from a training config into its final checkpoint.

    Args:
        cfg_file: Path to the training YAML config. Defaults to the
            previously hard-coded ``./config/train_config.yaml``, so a bare
            ``main()`` call behaves exactly as before.
    """
    pth_write(cfg_file)


if __name__ == '__main__':
    # Expose main() on the command line via the already-imported `fire`.
    # Running with no arguments calls main() with its defaults, as before;
    # e.g. `python this_script.py --cfg_file=other.yaml` overrides the config.
    fire.Fire(main)
